from visual import *

from Tracer import *
from Spin import *


 
class VisualSphere:
    """VPython view of a 2-D lattice of spins drawn as arrows from the origin.

    Every lattice site (x, y) owns one ``Spin`` (its arrow) and one ``Tracer``
    (a trail of past arrow positions).  Per-site colour factors are looked up
    in ``zArray`` / ``phiArray`` at index ``[x % stripeSpacingX][y %
    stripeSpacingY]``, i.e. per sublattice of the stripe pattern.
    """

    def __init__(self, nSpinsX, nSpinsY, stripeSpacingX, stripeSpacingY, tracerLength, k):
        """Open the display window and create the arrows, tracers and guides.

        Args:
            nSpinsX, nSpinsY: number of lattice sites along x and y.
            stripeSpacingX, stripeSpacingY: sublattice periods; used as the
                modulus when indexing the lattice's zArray / phiArray.
            tracerLength: passed to every ``Tracer`` constructor.
            k: direction onto which spins are projected to colour the arrows
                (normalised in ``updateVectors``).
        """
        self.window = display(title='Spin Sphere', width=420, height=420, x=40, y=440, center=(0,0,0), background=(.0,.0,.0))
        # Tiny field of view: presumably to approximate an orthographic view.
        self.window.fov = pi/900000.0
        self.window.lights = [vector(.5,.5,.5), vector(-.5,-.5,-.5)]
        self.window.up = (0.0,0.001,1.0)
        self.window.forward = (-0.5,1.0,-0.8)
        self.window.center = (0.0,0.0,0.0)
        self.window.range = 2.0
        self.window.autoscale = 0
        self.window.visible = 1
        self.window.select()

        self.nSpinsX = nSpinsX
        self.nSpinsY = nSpinsY

        self.stripeSpacingX = stripeSpacingX
        self.stripeSpacingY = stripeSpacingY

        self.k = k

        self.tracerLength = tracerLength

        # Visibility flags; not read by any active method of this class.
        self.tvVectorsPhaseColor = 0
        self.tvSWAxis = 0

        # One Spin (arrow) per lattice site, indexed [x][y].
        self.spinArray = [[Spin() for y in range(nSpinsY)] for x in range(nSpinsX)]

        # Number of whole stripe unit cells along each axis.  Floor division
        # keeps these ints (valid for range()) on Python 3 as well as 2.
        self.ucX = self.nSpinsX//self.stripeSpacingX
        self.ucY = self.nSpinsY//self.stripeSpacingY


##        self.ucSpinArray = [ None ] * self.ucX
##        for x in range(self.ucX):
##            self.ucSpinArray[x] = [0] * self.ucY
##            for y in range(self.ucY):
##                self.ucSpinArray[x][y] = Spin()

        # One Tracer per lattice site, indexed [x][y].
        self.tracerArray = [[Tracer(self.tracerLength) for y in range(nSpinsY)] for x in range(nSpinsX)]

##        self.sumTracer = Tracer(tracerLength)
##
##        self.swAxis = []  
##        self.swAxisLine = curve(color=(.2,.5,.95), visible=self.tvSWAxis)
##        self.swAxis2 = []  
##        self.swAxisLine2 = curve(color=(.8,.2,.3), visible=self.tvSWAxis)

        # Unit-circle equator (x = cos, y = sin) drawn as a reference guide.
        theta = arange(0,2.1*pi,.1565)
        self.equatorLine = curve(x=cos(theta), y=sin(theta), color=(.5,.5,.5))

        # Small axis triad with labels in the lower-left corner of the scene.
        self.xAxisLine = curve(pos=[(-1.4,-1.3,0.0),(-0.8,-1.3,0.0)], color=(.9,.9,.9))
        self.yAxisLine = curve(pos=[(-1.3,-1.4,0.0),(-1.3,-0.8,0.0)], color=(.9,.9,.9))
        self.zAxisLine = curve(pos=[(-1.3,-1.3,-0.1),(-1.3,-1.3,0.5)], color=(.9,.9,.9))

        self.xAxisLabel = label(pos=vector(-0.8,-1.3,0.0), text='x', xoffset=0, yoffset=-.001, space=.1, height=10, box=0, opacity=0, line=0, border=4, color=(.9,.9,.9), visible=1)
        self.yAxisLabel = label(pos=vector(-1.3,-0.8,0.0), text='y', xoffset=-.001, yoffset=.001, space=.1, height=10, box=0, opacity=0, line=0, border=4, color=(.9,.9,.9), visible=1)
        self.zAxisLabel = label(pos=vector(-1.3,-1.3,0.5), text='z', xoffset=.001, yoffset=.001, space=.1, height=10, box=0, opacity=0, line=0, border=4, color=(.9,.9,.9), visible=1)

##        self.addPoints()
        self.addVectors()

    ## COLOUR HELPER

    @staticmethod
    def _phaseColor(z, greenGain=1.0):
        """Map a normalised projection ``z`` to an (r, g, b) tuple.

        Red is (1 + z)/2, blue is (1 - z)/2 and green is
        ``greenGain * (1 - z**2)``, peaking at z == 0.  Nothing is clamped,
        so |z| > 1 (or greenGain > 1) yields components outside [0, 1].
        """
        return ((1.0+z)/2.0, greenGain*(1.0-z**2), (1.0-z)/2.0)

    ## TRACER METHODS

    def updateTracers(self, nLattice):
        """Advance every tracer by one step using its spin's current arrow axis.

        The position comes from ``Spin.getVectorAxis()``, which
        ``updateVectors`` refreshes, so presumably this should be called
        after ``updateVectors`` for the same frame.  The colour encodes the
        z component of the spin, scaled per sublattice.
        """
        # The +.001 offset keeps the divisor away from zero when the scale is 0.
        phiScale = nLattice.returnPhiScale()+.001
        phiArray = nLattice.returnPhiArray()
        currentState = nLattice.returnState()
        zArray = nLattice.returnzArray()

        for x in range(self.nSpinsX):
            xMod = x%self.stripeSpacingX
            for y in range(self.nSpinsY):
                yMod = y%self.stripeSpacingY
                # currentState[x][y][0] is taken to be the spin vector (it is
                # what updateVectors feeds to the arrow) -- confirm against
                # nLattice.  zArray presumably holds a per-sublattice sign.
                z = zArray[xMod][yMod]*currentState[x][y][0][2]/(phiScale*phiArray[xMod][yMod])
                location = self.spinArray[x][y].getVectorAxis()
                # NOTE(review): tracers use a green gain of 3 (arrows use 1),
                # so green exceeds 1.0 whenever |z| < ~0.82 -- confirm intent.
                self.tracerArray[x][y].step(location, self._phaseColor(z, 3))

##        color = (1.0,1.0,1.0)
##        sumPosition = self.spinArray[0][0].getVectorAxis() + self.spinArray[1][0].getVectorAxis() + self.spinArray[2][0].getVectorAxis()
##        self.sumTracer.step(sumPosition, color)

    def resetTracers(self):
        """Call ``delTracer()`` on every site's tracer."""
        for x in range(self.nSpinsX):
            for y in range(self.nSpinsY):
                self.tracerArray[x][y].delTracer()

    def toggleTracers(self):
        """Flip the visibility of every site's tracer."""
        for x in range(self.nSpinsX):
            for y in range(self.nSpinsY):
                self.tracerArray[x][y].toggleVisibility()

    ## VECTOR METHODS

    def addVectors(self):
        """Create every spin's arrow at the origin with a fixed initial colour."""
        arrowPos = (0.0,0.0,0.0)
        z = .7
        color = ((1+z)/2,0.0,(1-z)/2)
        width = 0.025
        for x in range(self.nSpinsX):
            for y in range(self.nSpinsY):
                self.spinArray[x][y].setArrow(arrowPos, color, width)

##        for x in range(self.ucX):
##            for y in range(self.ucY):
##                arrowPos = (0.0,0.0,0.0)
##                z=.7
##                color = ((1+z)/2,0.0,(1-z)/2)
##                width = 0.0025
##                                                                                              
##                self.ucSpinArray[x][y].setArrow(arrowPos, color, width)

    def updateVectors(self, nLattice):
        """Re-orient and recolour every spin arrow from the lattice state.

        Every site in the full nSpinsX x nSpinsY grid is visited, including
        any remainder sites when the grid is not a whole number of stripe
        unit cells.  Each arrow's axis is set to ``currentState[x][y][0]``.
        Its colour comes from the projection of the state onto ``norm(self.k)``,
        scaled per sublattice.
        """
        # The +.001 offset keeps the divisor away from zero when the scale is 0.
        phiScale = nLattice.returnPhiScale()+.001
        phiArray = nLattice.returnPhiArray()
        currentState = nLattice.returnState()
        zArray = nLattice.returnzArray()

        # Loop invariants: projection direction, arrow anchor and width.
        k1 = norm(self.k)
        arrowPos = (0.0,0.0,0.0)
        width = 0.025

        for x in range(self.nSpinsX):
            xMod = x%self.stripeSpacingX
            for y in range(self.nSpinsY):
                yMod = y%self.stripeSpacingY
                new = currentState[x][y]
                spin = self.spinArray[x][y]

                spin.updateVectorAxis(new[0])

                z = zArray[xMod][yMod]*dot(new, k1)[0]/(phiScale*phiArray[xMod][yMod])
                spin.setArrow(arrowPos, self._phaseColor(z), width)

##        arrowPos = (0.0,0.0,0.0)


##        for xC in range(self.ucX):
##            for yC in range(self.ucY):
##
##                spinUCsum = (0.0,0.0,0.0)
##
##                for xS in range(self.stripeSpacingX):
##                    for yS in range(self.stripeSpacingY):
##                        spinUCsum = nLattice.returnState()[xS + xC*self.stripeSpacingX][yS + yC*self.stripeSpacingY][0] + spinUCsum
##
##                self.ucSpinArray[xC][yC].updateVectorAxis(spinUCsum)
##                
##                color = (0.9,0.8,0.9)
##                width = 0.0025
##                                                                                              
##                self.ucSpinArray[xC][yC].setArrow(arrowPos, color, width)
##                arrowPos = spinUCsum + arrowPos

    def toggleVectors(self):
        """Flip the visibility of every spin's arrow."""
        for x in range(self.nSpinsX):
            for y in range(self.nSpinsY):
                self.spinArray[x][y].toggleVectorVisibility()

##        for x in range(self.ucX):
##            for y in range(self.ucY):
##                self.ucSpinArray[x][y].toggleVectorVisibility()

    ## POINT METHODS

##    def addPoints(self):
##        for x in range(self.nSpinsX):
##            for y in range(self.nSpinsY):
##                colorArg = 0.7
##                color = (colorArg,0.0,1.0-colorArg)
##                radius = 0.03
##                self.spinArray[x][y].setPoint(color, radius)
##
##    def updatePoints(self, nLattice):
##        energyArray = nLattice.returnEnergies() 
##
##        pointPos = (0.0,0.0,0.0)
##
##        for x in range(self.nSpinsX):
##            for y in range(self.nSpinsY):
##                colorArg = energyArray[x][y]
##                color = (colorArg,0.0,1.0-colorArg)
##                self.spinArray[x][y].updatePointPos(pointPos)
##                self.spinArray[x][y].updatePointColor(color)
##                pointPos = nLattice.returnState()[(x)%self.nSpinsX][(y)%self.nSpinsY][0] + pointPos
##
##    def togglePoints(self):
##        for x in range(self.nSpinsX):
##            for y in range(self.nSpinsY):
##                self.spinArray[x][y].togglePointVisibility()

    ## TORQUE VECTOR METHODS
##
##    def addTorqueVectors(self):
##        for x in range(self.nSpinsX):
##            for y in range(self.nSpinsY):
##                arrowPos = nLattice.returnTorqueArray()[x][y][0]
##                color=(.5,.5,1)
##                width=0.02
##                self.spinArray[x][y].setTorqueArrow(arrowPos, color, width)
##
##
##    def updateTorqueVectors(self, nLattice):
##        for x in range(self.nSpinsX):
##            for y in range(self.nSpinsY):
##                self.spinArray[x][y].updateTorqueVectorPos(nLattice.returnState()[x][y][0])
##                self.spinArray[x][y].updateTorqueVectorAxis(nLattice.returnTorqueArray()[x][y][0])
##
##
##    def toggleTorqueVectors(self):
##        for x in range(self.nSpinsX):
##            for y in range(self.nSpinsY):
##                self.spinArray[x][y].toggleTorqueVectorVisibility()

    ## AXIS METHODS

##    def updateAxis(self):
####        self.axis = self.numerics.returnSpinSum()
##        self.swAxis = [0.0*self.axis,self.axis]
##        self.swAxisLine.pos = self.swAxis
##        self.swAxisLine2.pos = self.swAxis2
##
##    def toggleAxis(self):
##        self.tvSWAxis = (self.tvSWAxis+1)%2
##        self.swAxisLine.visible = self.tvSWAxis
##        self.swAxisLine2.visible = self.tvSWAxis
##
##    def updateViewAxis(self):
####        self.axis = self.numerics.returnSpinSum()
##        self.window.forward = norm(self.axis+vector(0.0,0.0,0.000001))

    ## OTHER, STATE

##    def updateState(self):
##        self.spinArray = self.numerics.returnState()
##        self.meanFieldArray = self.numerics.returnMeanFieldArray()

    def returnArray(self):
        """Return the [x][y] nested list of ``Spin`` objects (not a copy)."""
        return self.spinArray







##        self.drag = 0
##        self.obs = vector(0.0,0.0,0.0)      
##            if self.paused == 1 or self.dt == 0.0:
##                if self.vSphere.window.mouse.events:
##                    m = self.vSphere.window.mouse.getevent()
##                    if m.click:
##                        newpos = self.vSphere.window.mouse.project(normal = self.vSphere.window.forward, d=1)
##                        spherePoint = norm(newpos-self.vSphere.window.forward)
##
##                        self.nLattice.moveSpin(vector(spherePoint))


##    def changeTracerLength(self, tracerLength):
##        for x in range(self.nSpinsX):
##            for y in range(self.nSpinsY):
##                self.tracerArray[x][y].tracer.


##    def resetVisuals(self):
##        for x in range(self.nSpinsX):
##            for y in range(self.nSpinsY):
##                self.spinArray[x][y].tracer.visible = 0
##                self.spinArray[x][y].point.visible = 0
##                self.spinArray[x][y].vector.visible = 0


